"""
## LaVCa step 4: Derive concise voxel captions by extracting and filtering keywords from the image captions,
## then feeding these keywords into a "Sentence Composer".


## For main analysis
python -m LaVCa.voxel_caption \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model MeaCap \
    --keywords_model gpt-4o-2024-08-06 \
    --correct_model default \
    --candidate_num 50 \
    --key_num 5 \
    --temperature 0.05 \
    --filter_th 0.15 \
    --device cuda

## For ablation study

# Concat-N (for Top-1, set candidate_num = 1)
python -m LaVCa.voxel_caption \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model MeaCap \
    --keywords_model default \
    --correct_model default \
    --candidate_num 10 \
    --key_num -1 \
    --temperature -1 \
    --filter_th -1 \
    --device cuda
"""

import torch
import argparse
import os
import json
from tqdm import tqdm
from utils.utils import (
    make_filename, create_and_check_temp_file, check_existing_keywords_file, 
    find_existing_keywords_file_if_needed, load_top_captions, save_top_captions_to_json, 
    get_layer_info_and_weight, load_reducer_projector, load_keyword_pipeline
)
from LaVCa.llm_utils import summarizer, load_lm_model, load_vl_model, load_parser_model_and_tokenizer
import numpy as np
from nsd.nsd_access import NSDAccess
from sentence_transformers import SentenceTransformer
from MeaCap.utils.generate_utils_ import Get_shuffle_score
from MeaCap.utils.detect_utils import retrieve_concepts
from MeaCap.args import get_class_args
from nltk.corpus import stopwords

# Seed PyTorch's global RNG with a fixed value as soon as this module is loaded.
torch.manual_seed(42)



def handle_concat_n(
    captions_list,
    keywords_file_path,
    temp_file_path,
    args
):
    """
    Explanation:
    Handles the special case for the Concat-N ablation. When key_num == -1 and
    caption_model == "default", the voxel text is simply the first
    candidate_num captions joined with ". \\n"; no keywords are extracted
    (the "keywords" field is stored as null).

    NOTE(review): the Concat-N usage example in the module docstring passes
    ``--caption_model MeaCap``, which does not satisfy this condition. Confirm
    which of the two is intended.

    Args:
        captions_list: Captions for the voxel; only the first
            args.candidate_num entries are used.
        keywords_file_path: Path of the JSON file to write.
        temp_file_path: Temp file of this voxel; removed once the voxel has
            been handled here.
        args: Namespace providing key_num, caption_model and candidate_num.

    Returns:
        True when the voxel was handled here (written now, or a parseable
        JSON file already existed) and the caller should stop; False when
        Concat-N mode is not active.
    """
    if int(args.key_num) != -1 or args.caption_model != "default":
        return False

    if os.path.exists(keywords_file_path):
        try:
            # The file is written as UTF-8 below, so read it back the same way.
            with open(keywords_file_path, "r", encoding="utf-8") as f:
                json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt (JSONDecodeError and UnicodeDecodeError are
            # both ValueError subclasses): fall through and regenerate it.
            pass
        else:
            print(f"Already processed: {keywords_file_path}")
            os.remove(temp_file_path)
            return True  # Means we skip further

    keys_and_sentences = {
        "keywords": None,
        "text": ". \n".join(captions_list[:args.candidate_num]),
    }
    with open(keywords_file_path, 'w', encoding='utf-8') as file:
        json.dump(keys_and_sentences, file, ensure_ascii=False, indent=4)
    print(f"Saved: {keywords_file_path}")
    os.remove(temp_file_path)
    return True


def extract_keywords(
    args,
    captions_list,
    voxel_index,
    weight_index_map,
    layer_weight,
    parser_model,
    parser_tokenizer,
    wte_model,
    stop_words,
    pipe,
    keys_and_text
):
    """
    Explanation:
    Depending on args.keywords_model, extracts or loads the keywords of a voxel
    and looks up the voxel's weight vector.

    - "default": keywords come from retrieve_concepts (parser-based).
    - "gpt-4o-2024-08-06" / "Llama-3.1-70B-Instruct": keywords already present
      in keys_and_text["keywords"] are reused; otherwise summarizer is called
      with int(args.key_num) as the requested number of keywords.
    - anything else: nothing is extracted.

    Args:
        args: Namespace providing keywords_model, device, key_num, modality.
        captions_list: Captions of the voxel passed to the extractor.
        voxel_index: Key into weight_index_map for this voxel.
        weight_index_map: Mapping from voxel index to a column of layer_weight.
        layer_weight: 2-D array indexed as layer_weight[:, weight_index]
            (presumably features x voxels; TODO confirm against the caller).
        parser_model, parser_tokenizer, wte_model: Forwarded to
            retrieve_concepts in the "default" mode.
        stop_words, pipe: Forwarded to summarizer in the LLM modes.
        keys_and_text: Previously saved result dict, or None/empty when no
            earlier result is available.

    Returns:
        (keywords, voxel_weight), where voxel_weight is a tensor with a leading
        axis of size 1 on args.device; ([], None) for an unsupported
        keywords_model.
    """
    llm_keyword_models = ("gpt-4o-2024-08-06", "Llama-3.1-70B-Instruct")

    if args.keywords_model == "default":
        # Default parser-based extraction
        sentences = retrieve_concepts(
            parser_model=parser_model,
            parser_tokenizer=parser_tokenizer,
            wte_model=wte_model,
            select_memory_captions=captions_list,
            image_embeds=None,
            device=args.device,
            logger=None
        )
    elif args.keywords_model in llm_keyword_models:
        # LLM-based concept extraction; reuse cached keywords when available.
        # `or {}` keeps this working when no earlier result was found.
        sentences = (keys_and_text or {}).get("keywords", [])
        if not sentences:  # Missing, null or empty: ask the LLM
            sentences = summarizer(
                args.keywords_model,
                args.modality,
                captions_list,
                int(args.key_num),
                stop_words,
                verbose=False,
                model=pipe
            )
    else:
        # Unsupported keyword model: signal "nothing to do" to the caller
        return [], None

    weight_index = weight_index_map[voxel_index]
    voxel_weight = torch.tensor(layer_weight[:, weight_index]).unsqueeze(0).to(args.device)
    return sentences, voxel_weight


def filter_keywords_by_similarity(
    masked_sentences,
    voxel_weight,
    vl_model,
    args
):
    """
    Explanation:
    Filters extracted keywords based on their similarity to the voxel weights.

    Each keyword is wrapped as "A picture of <keyword>." and embedded with
    vl_model.compute_text_representation. The cosine similarities to the
    voxel weight are turned into a softmax distribution (with
    args.temperature), and only keywords whose probability exceeds
    args.filter_th are kept.

    Args:
        masked_sentences: List of keyword strings.
        voxel_weight: Tensor compared against the text embeddings along dim 1.
        vl_model: Object exposing compute_text_representation(list_of_str).
        args: Namespace providing filter_th and temperature.

    Returns:
        - masked_sentences unchanged when filter_th <= 0 or the list is empty;
        - otherwise the surviving keywords sorted by score, highest first;
        - if no keyword survives, the full list sorted by score (previously
          this case returned an empty list because the scores were dropped
          while the sentences were restored).
    """
    if args.filter_th <= 0 or not masked_sentences:
        return masked_sentences

    sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    masked_sentences_cap = [
        "A picture of " + sentence + "." for sentence in masked_sentences
    ]
    masked_sentences_embs = vl_model.compute_text_representation(
        masked_sentences_cap
    )
    # Normalize
    voxel_weight = voxel_weight / voxel_weight.norm(p=2)

    sim_scores = sim_func(masked_sentences_embs, voxel_weight)
    sim_scores = torch.nn.functional.softmax(sim_scores / args.temperature, dim=0)

    # Threshold on the tensor (same comparison as before), then move to plain
    # Python floats so sorting does not compare tensors.
    keep_mask = (sim_scores > args.filter_th).tolist()
    scores = sim_scores.tolist()
    scored = list(zip(scores, masked_sentences))

    kept = [pair for pair, keep in zip(scored, keep_mask) if keep]
    if not kept:
        # If everything was filtered out, revert to the full list
        kept = scored

    # Sort by similarity score descending (ties fall back to the keyword text)
    kept.sort(reverse=True)
    return [sentence for _, sentence in kept]


def generate_and_postprocess_caption(
    voxel_weight,
    masked_sentences,
    captions_list,
    wte_model,
    tokenizer,
    lm_model,
    vl_model,
    stop_tokens_tensor,
    sub_tokens_tensor,
    reducer_projector,
    args,
    weight_index_map,
    voxel_index
):
    """
    Explanation:
    Composes the final voxel caption from the (filtered) keywords with
    Get_shuffle_score and tidies the resulting text.

    The first 10 captions are embedded with wte_model and handed to
    Get_shuffle_score together with the voxel weight, the keywords and the
    language / vision-language models. Only the first generated text is used.

    Post-processing keeps everything up to the first period (appending one if
    there is none), strips leading spaces, lower-cases the text, capitalizes
    the first character and removes curly braces.

    Returns:
        The post-processed caption as a single string.
    """
    n_context_captions = 10  # Hardcoded
    context_embeddings = torch.tensor(
        wte_model.encode(captions_list[:n_context_captions]),
        dtype=torch.float32,
    ).to(args.device)

    # Configure the MeaCap argument object for this voxel
    meacap_args = get_class_args()
    overrides = {
        "prompt_len": 0,
        "best_layer": weight_index_map[voxel_index],  # not always used
        "device": args.device,
        "dimreducer": reducer_projector,
        "correct_model": args.correct_model,
        "target_model": None,
        "target_tokenizer": None,
    }
    for attr_name, attr_value in overrides.items():
        setattr(meacap_args, attr_name, attr_value)

    candidates = Get_shuffle_score(
        voxel_weight, masked_sentences, lm_model, vl_model, wte_model, tokenizer,
        context_embeddings, stop_tokens_tensor, sub_tokens_tensor,
        None, None, meacap_args, args.device
    )

    # Post-processing: keep a single sentence that ends with exactly one period.
    # partition() yields the whole text when there is no period at all.
    caption = candidates[0].partition('.')[0] + '.'
    caption = caption.lstrip(' ').lower().capitalize()
    return caption.replace("{", "").replace("}", "")


def save_final_json(
    keywords_file_path,
    best_text,
    sentences,
    masked_sentences,
    temp_file_path,
    args
):
    """
    Explanation:
    Saves the final JSON result (keywords and final text) and removes the
    voxel's temp file afterward.

    If keywords_file_path already holds a JSON object whose "text" entry is
    not the empty string, the voxel counts as already processed: nothing is
    written and only the temp file is removed. An unreadable, corrupt or
    non-object file is overwritten.

    Args:
        keywords_file_path: Destination JSON path.
        best_text: Final caption, stored under "text".
        sentences: Keywords stored under "keywords".
        masked_sentences: Keywords stored under "filtered_keywords"; only
            written when args.keywords_model is not "default".
        temp_file_path: Temp file of this voxel; always removed on return.
        args: Namespace providing keywords_model.
    """
    if os.path.exists(keywords_file_path):
        try:
            # Written as UTF-8 below, so read it back the same way.
            with open(keywords_file_path, 'r', encoding='utf-8') as f:
                existing = json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt (JSONDecodeError / UnicodeDecodeError are
            # ValueError subclasses): treat as missing and rewrite below.
            existing = None

        if isinstance(existing, dict) and existing.get("text", "") != "":
            print(f"Already processed: {keywords_file_path}")
            os.remove(temp_file_path)
            return

    keys_and_sentences = {"keywords": sentences}
    if args.keywords_model != "default":
        keys_and_sentences["filtered_keywords"] = masked_sentences
    keys_and_sentences["text"] = best_text

    with open(keywords_file_path, 'w', encoding='utf-8') as file:
        json.dump(keys_and_sentences, file, ensure_ascii=False, indent=4)

    os.remove(temp_file_path)


def process_voxel(
    voxel_index: int,
    args,
    volume_index,
    weight_index_map,
    layer_weight,
    wte_model,
    vl_model,
    parser_model,
    parser_tokenizer,
    pipe,
    stop_words,
    tokenizer,
    lm_model,
    stop_tokens_tensor,
    sub_tokens_tensor,
    reducer_projector
):
    """
    Explanation:
    This function processes a single voxel end-to-end by:
      1) Setting up directories and a temporary file,
      2) Checking if the voxel has already been processed,
      3) Loading top captions and saving them,
      4) Handling any special ablation modes (Concat-N),
      5) Extracting or loading keywords,
      6) Optionally filtering keywords by similarity,
      7) Generating a final caption,
      8) Saving the result as a JSON file.

    The output directory is built from args.subject_names[0], so the caller
    must pass an args object whose first subject is the one this voxel
    belongs to.

    The temp file created in step 1 is used to skip voxels that are being
    handled elsewhere (create_and_check_temp_file returns None in that case).
    If any later step raises, the temp file is removed before the exception
    propagates, so a failed voxel is not left marked as in progress.

    volume_index is accepted for interface compatibility and is not used.

    Returns:
        None. Results are written to disk.
    """
    # Create necessary directories and paths
    filename = make_filename(args.reduce_dims[0:2])
    vindex_pad = str(voxel_index).zfill(6)
    resp_save_path = (
        f"./data/nsd/insilico/{args.subject_names[0]}/"
        f"{args.dataset_name}_{args.max_samples}/"
        f"{args.modality}/{args.modality_hparam}/"
        f"{args.model_name}_{filename}/whole/voxel{vindex_pad}"
    )
    os.makedirs(resp_save_path, exist_ok=True)

    # 1) Temp file handling
    temp_file_path = create_and_check_temp_file(voxel_index, resp_save_path, args)
    if temp_file_path is None:
        return  # Skip due to parallel processing

    try:
        # 2) Check existing keywords/caption file
        keywords_file_path = (
            f"{resp_save_path}/keys_and_text_{args.caption_model}_kmodel_{args.keywords_model}_"
            f"{args.key_num}keys_{args.temperature}temp_{args.filter_th}th_{args.candidate_num}cands_"
            f"cmodel_{args.correct_model}.json"
        )
        if check_existing_keywords_file(voxel_index, keywords_file_path, temp_file_path, args):
            return  # Skip

        # 3) Possibly find existing keywords if needed
        keys_and_text = find_existing_keywords_file_if_needed(resp_save_path, keywords_file_path, args)

        # 4) Load top captions, then save them
        captions_list = load_top_captions(resp_save_path, args)
        save_top_captions_to_json(captions_list, resp_save_path, args)

        # 5) Handle Concat-N ablation scenario
        if handle_concat_n(captions_list, keywords_file_path, temp_file_path, args):
            return  # All done in handle_concat_n

        # 6) Extract or load keywords
        sentences, voxel_weight = extract_keywords(
            args,
            captions_list,
            voxel_index,
            weight_index_map,
            layer_weight,
            parser_model,
            parser_tokenizer,
            wte_model,
            stop_words,
            pipe,
            keys_and_text
        )

        # If something went wrong or no valid voxel_weight
        if voxel_weight is None:
            os.remove(temp_file_path)
            return

        # 7) Optionally filter keywords by similarity. The unfiltered list is
        #    kept in `sentences` so both versions can be saved in step 9.
        masked_sentences = filter_keywords_by_similarity(
            sentences,
            voxel_weight,
            vl_model,
            args
        )

        # 8) Generate final caption
        best_text = generate_and_postprocess_caption(
            voxel_weight,
            masked_sentences,
            captions_list,
            wte_model,
            tokenizer,
            lm_model,
            vl_model,
            stop_tokens_tensor,
            sub_tokens_tensor,
            reducer_projector,
            args,
            weight_index_map,
            voxel_index
        )

        # 9) Save final JSON results and clean up
        save_final_json(
            keywords_file_path,
            best_text,
            sentences,         # 'keywords': as extracted, before filtering
            masked_sentences,  # 'filtered_keywords'
            temp_file_path,
            args
        )
    except BaseException:
        # Do not leave the temp file behind on failure or interruption; the
        # helpers above may already have removed it, hence the existence check.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        raise
    
def main(args):
    """
    Explanation:
    Main pipeline that orchestrates:
    1) Loading the NSDAccess object
    2) Loading various models (parser, LM, VL model, etc.)
    3) Iterating over each subject, retrieving layer weights, volume indices
    4) Iterating over voxels to extract or generate text

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).

    Errors raised while processing a single voxel are printed and the loop
    continues with the next voxel.
    """
    score_root_path = "./data/nsd/encoding"
    nsda = NSDAccess('./data/NSD')

    # Load parser model and tokenizer
    parser_checkpoint = "lizhuang144/flan-t5-base-VG-factual-sg"
    parser_tokenizer, parser_model = load_parser_model_and_tokenizer(parser_checkpoint, args.device)

    # Load WTE and VL model
    wte_model_path = "sentence-transformers/all-MiniLM-L6-v2"
    wte_model = SentenceTransformer(wte_model_path)
    vl_model = load_vl_model(args.device)

    # Load LM model for captioning
    tokenizer, lm_model, stop_tokens_tensor, sub_tokens_tensor = load_lm_model(args)

    # Load LLM pipeline if necessary
    pipe = load_keyword_pipeline(args.keywords_model)

    stop_words = set(stopwords.words('english'))

    # Iterate over each subject
    for subject_name in args.subject_names:
        print(subject_name)
        # Get layer weight and volume indices
        (
            target_best_cv_layer,
            target_best_cv_layer_num,
            layer_weight,
            volume_index,
            weight_index_map
        ) = get_layer_info_and_weight(subject_name, args, nsda, score_root_path)

        print(f"Best layer: {target_best_cv_layer}")
        print(f"Shape of the layer's weight: {layer_weight.shape}")

        # Load a reducer projector if requested
        reducer_projector = load_reducer_projector(args, subject_name, target_best_cv_layer)
        print(reducer_projector)

        # process_voxel derives its output directory from
        # args.subject_names[0]. Hand it a per-subject copy of the arguments
        # so that every subject in the loop uses its own directory instead of
        # the first subject's.
        subject_args = argparse.Namespace(**vars(args))
        subject_args.subject_names = [subject_name]

        # Process each voxel
        for voxel_index in tqdm(volume_index):
            try:
                process_voxel(
                    voxel_index,
                    subject_args,
                    volume_index,
                    weight_index_map,
                    layer_weight,
                    wte_model,
                    vl_model,
                    parser_model,
                    parser_tokenizer,
                    pipe,
                    stop_words,
                    tokenizer,
                    lm_model,
                    stop_tokens_tensor,
                    sub_tokens_tensor,
                    reducer_projector
                )
            except Exception as e:
                # Any unexpected error in voxel processing: report and move on
                print(f"Error processing voxel {voxel_index}: {e}")

if __name__ == "__main__":
    # Entry point: declare the command-line interface, parse it and hand the
    # resulting namespace to main(). See the module docstring for example
    # invocations.
    #
    # Each entry is (flag, keyword arguments for add_argument); the order is
    # the order in which the options are registered.
    argument_specs = [
        ("--subject_names", dict(nargs="*", type=str, required=True)),
        ("--atlasname", dict(type=str, nargs="*", required=True)),
        ("--modality", dict(
            type=str, required=True, help="Name of the modality to use.")),
        ("--modality_hparam", dict(type=str, required=True)),
        ("--model_name", dict(type=str, required=True)),
        ("--reduce_dims", dict(nargs="*", type=str, required=True)),
        ("--dataset_name", dict(type=str, required=True)),
        ("--max_samples", dict(type=str, required=True)),
        ("--dataset_path", dict(type=str, required=True)),
        ("--dataset_captioner", dict(type=str, required=True)),
        ("--voxel_selection", dict(
            nargs="*", type=str, required=True,
            help="Selection method of voxels. Implemented type are 'uv' and 'share'.")),
        ("--layer_selection", dict(type=str, required=False, default="best")),
        ("--caption_model", dict(
            type=str, required=True, help="Name of the captioning model to use.")),
        ("--keywords_model", dict(
            type=str, required=True, help="Name of the keyword model to use.")),
        ("--correct_model", dict(
            type=str, required=True, help="Name of the correction model to use.")),
        ("--correct_style", dict(
            type=str, choices=["default", "few-shot"],
            help="Name of the correction style to use.")),
        ("--candidate_num", dict(type=int, required=True)),
        # key_num stays a string here; consumers convert it with int()
        ("--key_num", dict(type=str, required=True)),
        ("--temperature", dict(type=float, required=True)),
        ("--filter_th", dict(type=float, required=False, default=-1)),
        ("--device", dict(type=str, required=True, help="Device to use.")),
        ("--only_summarize", dict(
            action="store_true", required=False, default=False)),
        ("--only_caption", dict(
            action="store_true", required=False, default=False)),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
